{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Import Packages\n",
    "Add the repository root to the system path, then import the required packages and builder functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "# A notebook has no __file__, so the repository root is derived from the working\n",
    "# directory: ROOT is the parent of the folder this notebook is run from.\n",
    "ROOT = str(Path.cwd().resolve().parent)\n",
    "# Only extend sys.path once so that re-running this cell stays idempotent.\n",
    "if ROOT not in sys.path:\n",
    "    sys.path.append(ROOT)\n",
    "import torch\n",
    "import argparse\n",
    "import os.path as osp\n",
    "from mmcv import Config\n",
    "from trademaster.utils import replace_cfg_vals\n",
    "from trademaster.nets.builder import build_net\n",
    "from trademaster.environments.builder import build_environment\n",
    "from trademaster.datasets.builder import build_dataset\n",
    "from trademaster.agents.builder import build_agent\n",
    "from trademaster.optimizers.builder import build_optimizer\n",
    "from trademaster.losses.builder import build_loss\n",
    "from trademaster.trainers.builder import build_trainer\n",
    "from trademaster.transition.builder import build_transition"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Load Configs\n",
    "Load the default config from the file `configs/portfolio_management/portfolio_management_dj30_eiie_eiie_adam_mse.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Options mirror a command-line script; parse_args([]) keeps the defaults inside the notebook.\n",
    "parser = argparse.ArgumentParser(description='Portfolio management on DJ30 with EIIE')\n",
    "parser.add_argument(\"--config\", default=osp.join(ROOT, \"configs\", \"portfolio_management\", \"portfolio_management_dj30_eiie_eiie_adam_mse.py\"),\n",
    "                    help=\"path to the experiment config file\")\n",
    "parser.add_argument(\"--task_name\", type=str, default=\"train\",\n",
    "                    help=\"task to run: train, test or style_test\")\n",
    "parser.add_argument(\"--test_style\", type=str, default='-1')\n",
    "args = parser.parse_args([])\n",
    "cfg = Config.fromfile(args.config)\n",
    "task_name = args.task_name\n",
    "\n",
    "# Resolve the config through replace_cfg_vals before it is used by the builders.\n",
    "cfg = replace_cfg_vals(cfg)\n",
    "# Record the requested test style in the data section of the config.\n",
    "cfg.data.update({'test_style': args.test_style})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Config (path: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/configs/portfolio_management/portfolio_management_dj30_eiie_eiie_adam_mse.py): {'data': {'type': 'PortfolioManagementDataset', 'data_path': 'data/portfolio_management/dj30', 'train_path': 'data/portfolio_management/dj30/train.csv', 'valid_path': 'data/portfolio_management/dj30/valid.csv', 'test_path': 'data/portfolio_management/dj30/test.csv', 'tech_indicator_list': ['zopen', 'zhigh', 'zlow', 'zadjcp', 'zclose', 'zd_5', 'zd_10', 'zd_15', 'zd_20', 'zd_25', 'zd_30'], 'length_day': 10, 'initial_amount': 100000, 'transaction_cost_pct': 0.001, 'test_style_path': 'data/portfolio_management/dj30/DJI_label_by_DJIindex_3_24_-0.25_0.25.csv', 'test_style': '-1'}, 'environment': {'type': 'PortfolioManagementEIIEEnvironment'}, 'agent': {'type': 'PortfolioManagementEIIE', 'memory_capacity': 1000, 'gamma': 0.99, 'policy_update_frequency': 500}, 'trainer': {'type': 'PortfolioManagementEIIETrainer', 'epochs': 10, 'work_dir': 'work_dir/portfolio_management_dj30_eiie_eiie_adam_mse', 'if_remove': True}, 'loss': {'type': 'MSELoss'}, 'optimizer': {'type': 'Adam', 'lr': 0.001}, 'act': {'type': 'EIIEConv', 'input_dim': None, 'output_dim': 1, 'time_steps': 10, 'kernel_size': 3, 'dims': [32]}, 'cri': {'type': 'EIIECritic', 'input_dim': None, 'action_dim': None, 'output_dim': 1, 'time_steps': None, 'num_layers': 1, 'hidden_size': 32}, 'transition': {'type': 'Transition'}, 'task_name': 'portfolio_management', 'dataset_name': 'dj30', 'net_name': 'eiie', 'agent_name': 'eiie', 'optimizer_name': 'adam', 'loss_name': 'mse', 'work_dir': 'work_dir/portfolio_management_dj30_eiie_eiie_adam_mse'}"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cfg"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Build Dataset\n",
    "Build datasets from cfg defined above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = build_dataset(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<trademaster.datasets.portfolio_management.dataset.PortfolioManagementDataset at 0x7f375ae6fe50>"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Build Reinforcement Learning Environments\n",
    "Build environments based on cfg and previously-defined dataset\n",
    "\n",
    "A style-test is provided as an option to test the algorithm's performance under different market conditions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One environment per data split, all built from the same dataset object.\n",
    "train_environment = build_environment(cfg, default_args={\"dataset\": dataset, \"task\": \"train\"})\n",
    "valid_environment = build_environment(cfg, default_args={\"dataset\": dataset, \"task\": \"valid\"})\n",
    "test_environment = build_environment(cfg, default_args={\"dataset\": dataset, \"task\": \"test\"})\n",
    "\n",
    "# Style-test environments are only built when a style test was requested.\n",
    "if task_name.startswith(\"style_test\"):\n",
    "    test_style_environments = []\n",
    "    for i, path in enumerate(dataset.test_style_paths):\n",
    "        style_env = build_environment(\n",
    "            cfg,\n",
    "            default_args={\"dataset\": dataset, \"task\": \"test_style\", \"style_test_path\": path, \"task_index\": i},\n",
    "        )\n",
    "        test_style_environments.append(style_env)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<trademaster.environments.portfolio_management.eiie_environment.PortfolioManagementEIIEEnvironment at 0x7f37635ad550>"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<trademaster.environments.portfolio_management.eiie_environment.PortfolioManagementEIIEEnvironment at 0x7f37635ad110>"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valid_environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<trademaster.environments.portfolio_management.eiie_environment.PortfolioManagementEIIEEnvironment at 0x7f375ae69b50>"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_environment"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Build Net\n",
    "Update the input, action and time-step dimensions in the config, then create the `act` and `cri` nets and their optimizers for EIIE.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensions are read from the training environment.\n",
    "action_dim = train_environment.action_dim  # 29 in the original DJ30 run\n",
    "state_dim = train_environment.state_dim  # 11 in the original DJ30 run\n",
    "input_dim = len(train_environment.tech_indicator_list)\n",
    "time_steps = train_environment.time_steps\n",
    "\n",
    "# Write the dimensions into the net configs before the nets are built.\n",
    "cfg.act.update({\"input_dim\": input_dim, \"time_steps\": time_steps})\n",
    "cfg.cri.update({\"input_dim\": input_dim, \"action_dim\": action_dim, \"time_steps\": time_steps})\n",
    "\n",
    "# Build both nets first, then one optimizer per net from the same optimizer config.\n",
    "act = build_net(cfg.act)\n",
    "cri = build_net(cfg.cri)\n",
    "act_optimizer = build_optimizer(cfg, default_args={\"params\": act.parameters()})\n",
    "cri_optimizer = build_optimizer(cfg, default_args={\"params\": cri.parameters()})"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6: Build Loss\n",
    "Build loss from config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = build_loss(cfg)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 7: Build Transition\n",
    "Build transition from config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "transition = build_transition(cfg)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 8: Build Agent\n",
    "Build agent from config and detect device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Everything the agent is assembled from, gathered in one place.\n",
    "agent_components = dict(\n",
    "    action_dim=action_dim,\n",
    "    state_dim=state_dim,\n",
    "    time_steps=time_steps,\n",
    "    act=act,\n",
    "    cri=cri,\n",
    "    act_optimizer=act_optimizer,\n",
    "    cri_optimizer=cri_optimizer,\n",
    "    criterion=criterion,\n",
    "    transition=transition,\n",
    "    device=device,\n",
    ")\n",
    "agent = build_agent(cfg, default_args=agent_components)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 9: Build Trainer\n",
    "Build the trainer from the config and create the work directory used to save the results, model and config."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Arguments Remove work_dir: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse\n"
     ]
    }
   ],
   "source": [
    "if task_name.startswith(\"style_test\"):\n",
    "    # One trainer per style-test environment; the train/valid environments are shared.\n",
    "    trainers = [\n",
    "        build_trainer(cfg, default_args=dict(train_environment=train_environment,\n",
    "                                             valid_environment=valid_environment,\n",
    "                                             test_environment=env,\n",
    "                                             agent=agent,\n",
    "                                             device=device))\n",
    "        for env in test_style_environments\n",
    "    ]\n",
    "else:\n",
    "    trainer = build_trainer(cfg, default_args=dict(train_environment=train_environment,\n",
    "                                                   valid_environment=valid_environment,\n",
    "                                                   test_environment=test_environment,\n",
    "                                                   agent=agent,\n",
    "                                                   device=device))\n",
    "\n",
    "# Keep a copy of the config used for this run inside the work directory.\n",
    "work_dir = osp.join(ROOT, cfg.trainer.work_dir)\n",
    "# exist_ok avoids the check-then-create race and is a no-op if the directory already exists.\n",
    "os.makedirs(work_dir, exist_ok=True)\n",
    "cfg.dump(osp.join(work_dir, osp.basename(args.config)))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 10: Train the Trainer\n",
    "Run the trainer according to `task_name` (train, test or style test); the results are saved in the work directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Episode: [1/10]\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "|  180.543176%  |   0.001605  |  0.007575  |   0.645235   |   1.688276   |    4.176283   |\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "Valid Episode: [1/10]\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "|   9.519411%   |   0.001804  |  0.021395  |   0.361807   |   0.406514   |    0.529803   |\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "Valid Episode Reward Sum: 0.090932\n",
      "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse/checkpoints/checkpoint-00001.pth\n",
      "Train Episode: [2/10]\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "|  180.534263%  |   0.001605  |  0.007575  |   0.645216   |   1.688270   |    4.176311   |\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "Valid Episode: [2/10]\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "|   9.520920%   |   0.001804  |  0.021394  |   0.361798   |   0.406550   |    0.529857   |\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "Valid Episode Reward Sum: 0.090945\n",
      "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse/checkpoints/checkpoint-00002.pth\n",
      "Train Episode: [3/10]\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "|  180.519361%  |   0.001605  |  0.007574  |   0.645198   |   1.688229   |    4.176275   |\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "Valid Episode: [3/10]\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "|   9.521613%   |   0.001804  |  0.021394  |   0.361794   |   0.406567   |    0.529882   |\n",
      "+---------------+-------------+------------+--------------+--------------+---------------+\n",
      "Valid Episode Reward Sum: 0.090952\n",
      "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse/checkpoints/checkpoint-00003.pth\n",
      "Train Episode: [4/10]\n"
     ]
    }
   ],
   "source": [
    "# Dispatch on the task name chosen in the config cell.\n",
    "if task_name.startswith(\"train\"):\n",
    "    trainer.train_and_valid()\n",
    "    trainer.test()\n",
    "    print(\"train end\")\n",
    "elif task_name.startswith(\"test\"):\n",
    "    trainer.test()\n",
    "    print(\"test end\")\n",
    "elif task_name.startswith(\"style_test\"):\n",
    "    # Pool the values returned by every style-test trainer before computing the win rate.\n",
    "    daily_return_list = []\n",
    "    for style_trainer in trainers:\n",
    "        daily_return_list.extend(style_trainer.test())\n",
    "    if daily_return_list:\n",
    "        print('win rate is: ', sum(r > 0 for r in daily_return_list) / len(daily_return_list))\n",
    "    else:\n",
    "        # Without any returns the ratio is undefined; report that instead of dividing by zero.\n",
    "        print('win rate is: ', 'undefined (no daily returns were collected)')\n",
    "    print(\"style test end\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "HFT",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.15"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "c33605b009166d65f90ad63d824c8e63d22d0973c031452c4b4158e2872c99ad"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
